{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c79d4c25",
   "metadata": {},
   "source": [
    "# Comparing Deep Learning Models for Sales Forecasting using Covalent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cd67d5f",
   "metadata": {},
   "source": [
    "## Summary\n",
    "In this example, we use Covalent to orchestrate a workflow which compares the performances of four different deep learning models -- the multi-layer perceptron (MLP) model, the convolutional neural network (CNN) model, the long short-term memory (LSTM) model, and a hybrid CNN-LSTM model. Each neural network is trained on daily historical sales data spanning four years. Covalent allows us to easily deploy different tasks to different hardware resources.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/dimitreOliveira/MachineLearning/master/Kaggle/Store%20Item%20Demand%20Forecasting%20Challenge/time-series%20graph.png\" width=\"800\">\n",
    "\n",
    "<img src=\"./images/workflow.gif\" width=\"800\">\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929e3c37",
   "metadata": {},
   "source": [
    "## About Covalent\n",
    "The Covalent API is a Python module containing a small collection of classes that implement server-based workflow management. The key elements are two decorators that wrap functions to create managed **tasks** and **workflows**.\n",
    "\n",
    "The **task** decorator is called an `electron`. The `@electron` decorator simply turns the function into a dispatchable task.\n",
    "The **workflow** decorator is called a `lattice`. The `@lattice` decorator turns a function composed of electrons into a manageable workflow.\n",
    "\n",
    "### Covalent Features Used\n",
    "\n",
    "In order to transform our data and train our models we will be leveraging AWS resources to offload those tasks to remote machines with our desired amount of compute resources. Below are all of the features of covalent that we use for this tutorial.\n",
    "\n",
    "- To upload our dataset to an S3 bucket we will utilize Covalent's [FileTransfer](../../concepts/concepts.rst#files) functionallity\n",
    "- In order for covalent to be able to execute tasks in a remote environment we will need to declare which pip packages are to be installed for specific tasks using [DepsPip](../../concepts/concepts.rst#depspip)\n",
    "- We will be offloading the dataset transformation tasks to AWS ECS using the [AWS ECS Executor](../../api/executors/awsecs.rst)\n",
    "- Finally we will be training the models in an AWS Batch compute environment using the [AWS Batch Executor](../../api/executors/awsbatch.rst)  \n",
    "\n",
    "\n",
    "### Heterogeneity\n",
    "Tasks in this example are heterogeneous in several ways:\n",
    "\n",
    "- Task Scheduling: Dask / AWS ECS (Fargate) / AWS Batch\n",
    "- Execution Backend: Local / Remote CPU\n",
    "\n",
    "## Overview of Tutorial\n",
    "\n",
    "The workflow performs the following steps:\n",
    "\n",
    "|Task|Resources|\n",
    "|---|---|\n",
    "|1. Fetch and validate a training dataset from a remote machine|![](./images/dask.png)![](./images/computer.png)|\n",
    "|2. Explore the dataset and visualize broad trends|![](./images/dask.png)![](./images/computer.png)|\n",
    "|3. Clean / transform the data|![](./images/aws-1.png)![](./images/aws-2.png)|\n",
    "|4. Split the Data for Training and Validation|![](./images/dask.png)![](./images/computer.png)|\n",
    "|5. Construct the DNN Models|![](./images/dask.png)![](./images/computer.png)|\n",
    "|6. Train the Models|![](./images/slurm.png)![](./images/aws-3.png)![](./images/gpu.png)|\n",
    "|7. Visualize the results|![](./images/dask.png)![](./images/computer.png)|\n",
    "\n",
    "This tutorial is derived from work by Dimitre Oliveira on [Kaggle](https://www.kaggle.com/code/dimitreoliveira/deep-learning-for-time-series-forecasting/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11f9fddf",
   "metadata": {},
   "source": [
    "In order for the workflow to succeed with the corresponding dataset, it's important to assign 2 GB of memory to each of the workers. For this workflow, Covalent was started with:\n",
    "```\n",
    "covalent start -n 4 -m \"2GB\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc42d673",
   "metadata": {},
   "source": [
    "## Implementation\n",
    "\n",
    "### Fetch and validate a training dataset from a remote machine\n",
    "\n",
    "We first import all of the libraries that are intended to be used"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a85435f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "covalent\n",
      "covalent-aws-plugins[batch]\n",
      "covalent-aws-plugins[ecs]\n",
      "pandas==1.4.4\n",
      "tensorflow==2.9.1\n"
     ]
    }
   ],
   "source": [
    "with open(\"./requirements.txt\", \"r\") as file:\n",
    "    for line in file:\n",
    "        print(line.rstrip())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6f464dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages\n",
    "# !pip install -r ./requirements.txt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6fc3f01",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import covalent as ct\n",
    "\n",
    "import cloudpickle\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tensorflow.keras import optimizers\n",
    "from keras.utils.vis_utils import plot_model\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers.convolutional import Conv1D, MaxPooling1D\n",
    "from keras.layers import Dense, LSTM, RepeatVector, TimeDistributed, Flatten\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "import plotly.graph_objs as go\n",
    "from plotly.offline import init_notebook_mode, iplot\n",
    "from typing import List\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "matplotlib.use(\"Agg\")\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "init_notebook_mode(connected=True)\n",
    "\n",
    "from tensorflow.random import set_seed\n",
    "from numpy.random import seed\n",
    "\n",
    "set_seed(1)\n",
    "seed(1)\n",
    "\n",
    "from covalent._shared_files import logger\n",
    "\n",
    "app_log = logger.app_log\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d85081e",
   "metadata": {},
   "source": [
    "We construct the first task, which fetches input data from a remote machine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f861fc54",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_bucket_uri = \"s3://dnn-tutorial-dataset\"\n",
    "strategy = ct.fs_strategies.S3(\n",
    "    credentials=\"/home/user/.aws/credentials\", profile=\"default\", region_name=\"us-east-1\"\n",
    ")\n",
    "\n",
    "\n",
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/train.csv\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/test.csv\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def fetch_and_validate_data(files=[]) -> List[str]:\n",
    "    _, local_train = files[0]\n",
    "    _, local_test = files[1]\n",
    "    train = pd.read_csv(local_train, parse_dates=[\"date\"])\n",
    "    test = pd.read_csv(local_test, parse_dates=[\"date\"])\n",
    "    return train, test\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "709104b2",
   "metadata": {},
   "source": [
    "### Explore the dataset and visualize broad trends\n",
    "Next, explore the dataset and visualize basic trends"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6d2ca22",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron\n",
    "def explore_data(dataset):\n",
    "    # Aggregate daily sales\n",
    "    daily_sales = dataset.groupby(\"date\", as_index=False)[\"sales\"].sum()\n",
    "    store_daily_sales = dataset.groupby([\"store\", \"date\"], as_index=False)[\"sales\"].sum()\n",
    "    item_daily_sales = dataset.groupby([\"item\", \"date\"], as_index=False)[\"sales\"].sum()\n",
    "\n",
    "    # Plot total daily sales\n",
    "    daily_sales_sc = go.Scatter(x=daily_sales[\"date\"], y=daily_sales[\"sales\"])\n",
    "    layout1 = go.Layout(title=\"Daily sales\", xaxis=dict(title=\"Date\"), yaxis=dict(title=\"Sales\"))\n",
    "    fig1 = go.Figure(data=[daily_sales_sc], layout=layout1)\n",
    "\n",
    "    # Sales by item\n",
    "    item_daily_sales_sc = []\n",
    "    for item in item_daily_sales[\"item\"].unique():\n",
    "        current_item_daily_sales = item_daily_sales[(item_daily_sales[\"item\"] == item)]\n",
    "        item_daily_sales_sc.append(\n",
    "            go.Scatter(\n",
    "                x=current_item_daily_sales[\"date\"],\n",
    "                y=current_item_daily_sales[\"sales\"],\n",
    "                name=(\"Item %s\" % item),\n",
    "            )\n",
    "        )\n",
    "\n",
    "    layout2 = go.Layout(\n",
    "        title=\"Item daily sales\", xaxis=dict(title=\"Date\"), yaxis=dict(title=\"Sales\")\n",
    "    )\n",
    "    fig2 = go.Figure(data=item_daily_sales_sc, layout=layout2)\n",
    "\n",
    "    return fig1, fig2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99476302",
   "metadata": {},
   "source": [
    "### Clean / transform the data\n",
    "Now we subsample and transform the data to a period of interest. We use the last year of data and for each date, use the previous month to forecast 90 days ahead.  In addition, this task will run on AWS Fargate on ECS.\n",
    "\n",
    "In order to run this electron on a remote executor, we need to specify the package dependcies via `ct.DepsPip(...)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8fd3eb6d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ecs_executor = ct.executor.ECSExecutor(\n",
    "    credentials=\"/home/user/.aws/credentials\",\n",
    "    region=\"us-east-1\",\n",
    "    profile=\"default\",\n",
    "    ecr_repo_name=\"dnn-tutorial-ecs-ecr-repo\",\n",
    "    ecs_cluster_name=\"dnn-tutorial-ecs-ecs-cluster\",\n",
    "    ecs_task_execution_role_name=\"dnn-tutorial-ecs-task-execution-role\",\n",
    "    ecs_task_log_group_name=\"dnn-tutorial-ecs-log-group\",\n",
    "    ecs_task_role_name=\"dnn-tutorial-ecs-task-role\",\n",
    "    ecs_task_security_group_id=\"sg-03f5c620fa4d4d5fc\",\n",
    "    ecs_task_subnet_id=\"subnet-0127aa9a0344152d0\",\n",
    "    s3_bucket_name=\"dnn-tutorial-ecs-bucket\",\n",
    "    vcpu=4,\n",
    "    memory=8,\n",
    "    retry_attempts=2,\n",
    "    poll_freq=10,\n",
    "    time_limit=3000,\n",
    "    cache_dir=\"/tmp/covalent\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0aeaa02b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "deps_pip = ct.DepsPip(packages=[\"pandas==1.4.4\"])\n",
    "\n",
    "\n",
    "@ct.electron(\n",
    "    deps_pip=deps_pip,\n",
    "    executor=ecs_executor,\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/train.csv\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/time_series.csv\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "def transform_data(window, lag_size, files=[]):\n",
    "    _, local_train = files[0]\n",
    "    dataset = pd.read_csv(local_train)\n",
    "\n",
    "    # Subsample\n",
    "    refined_data = dataset[(dataset[\"date\"] >= \"2017-01-01\")]\n",
    "\n",
    "    # Restructure\n",
    "    group = refined_data.sort_values(\"date\").groupby([\"item\", \"store\", \"date\"], as_index=False)\n",
    "    group = group.agg({\"sales\": [\"mean\"]})\n",
    "    group.columns = [\"item\", \"store\", \"date\", \"sales\"]\n",
    "\n",
    "    data = group.drop(\"date\", axis=1)\n",
    "    cols, names = list(), list()\n",
    "\n",
    "    # Input sequence (t-n, ... t-1)\n",
    "    for i in range(window, 0, -1):\n",
    "        cols.append(data.shift(i))\n",
    "        names += [(\"%s(t-%d)\" % (col, i)) for col in data.columns]\n",
    "\n",
    "    # Current timestep (t=0)\n",
    "    cols.append(data)\n",
    "    names += [(\"%s(t)\" % (col)) for col in data.columns]\n",
    "\n",
    "    # Target timestep (t=lag)\n",
    "    cols.append(data.shift(-lag_size))\n",
    "    names += [(\"%s(t+%d)\" % (col, lag_size)) for col in data.columns]\n",
    "\n",
    "    # Put it all together\n",
    "    time_series = pd.concat(cols, axis=1)\n",
    "    time_series.columns = names\n",
    "\n",
    "    # Drop rows with NaN values\n",
    "    time_series.dropna(inplace=True)\n",
    "\n",
    "    # Filter unwanted data\n",
    "    last_item = \"item(t-%d)\" % window\n",
    "    last_store = \"store(t-%d)\" % window\n",
    "    time_series = time_series[(time_series[\"store(t)\"] == time_series[last_store])]\n",
    "    time_series = time_series[(time_series[\"item(t)\"] == time_series[last_item])]\n",
    "\n",
    "    columns_to_drop = [(\"%s(t+%d)\" % (col, lag_size)) for col in [\"item\", \"store\"]]\n",
    "    for i in range(window, 0, -1):\n",
    "        columns_to_drop += [(\"%s(t-%d)\" % (col, i)) for col in [\"item\", \"store\"]]\n",
    "    time_series.drop(columns_to_drop, axis=1, inplace=True)\n",
    "    time_series.drop([\"item(t)\", \"store(t)\"], axis=1, inplace=True)\n",
    "\n",
    "    app_log.debug(f\"Transformed size info: {time_series.shape}\")\n",
    "\n",
    "    from_file, to_file = files[1]\n",
    "    time_series.to_csv(from_file, index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be0b1de7",
   "metadata": {},
   "source": [
    "### Split the Data for Training and Validation\n",
    "\n",
    "We store the data in the s3 bucket since they can potentially be very large and we want to avoid unnecessary pickling and unpickling during the workflow execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "afd6f9d8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Split the time series into training / validation sets\n",
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/time_series.csv\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/X_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/X_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/X_valid_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/Y_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferToRemote(\n",
    "            f\"{data_bucket_uri}/Y_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def split_data(lag_size, files=[]):\n",
    "    # Read file from s3 bucket\n",
    "    to_file, from_file = files[0]\n",
    "    time_series = pd.read_csv(from_file)\n",
    "\n",
    "    # Label\n",
    "    labels_col = \"sales(t+%d)\" % lag_size\n",
    "    labels = time_series[labels_col]\n",
    "    time_series = time_series.drop(labels_col, axis=1)\n",
    "\n",
    "    X_train, X_valid, Y_train, Y_valid = train_test_split(\n",
    "        time_series, labels.values, test_size=0.4, random_state=0\n",
    "    )\n",
    "    X_train_series = X_train.values.reshape((X_train.shape[0], X_train.shape[1], 1))\n",
    "    X_valid_series = X_valid.values.reshape((X_valid.shape[0], X_valid.shape[1], 1))\n",
    "\n",
    "    for idx, obj in enumerate(\n",
    "        [X_train, X_train_series, X_valid, X_valid_series, Y_train, Y_valid]\n",
    "    ):\n",
    "        from_file, to_file = files[idx + 1]\n",
    "        with open(from_file, \"wb\") as f:\n",
    "            cloudpickle.dump(obj, f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd6b27ff",
   "metadata": {},
   "source": [
    "### Construct the DNN Models\n",
    "\n",
    "Next frame contruction of the four DNN models as tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e90bf2b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train.pkl\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def construct_mlp_model(files=[]):\n",
    "    from_file, to_file = files[0]\n",
    "    with open(to_file, \"rb\") as f:\n",
    "        X_train = cloudpickle.load(f)\n",
    "    model_mlp = Sequential()\n",
    "    model_mlp.add(Dense(100, activation=\"relu\", input_dim=X_train.shape[1]))\n",
    "    model_mlp.add(Dense(1))\n",
    "\n",
    "    return model_mlp\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fe634c14",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def construct_cnn_model(files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    model_cnn = Sequential()\n",
    "    model_cnn.add(\n",
    "        Conv1D(\n",
    "            filters=64,\n",
    "            kernel_size=2,\n",
    "            activation=\"relu\",\n",
    "            input_shape=(X_train_series.shape[1], X_train_series.shape[2]),\n",
    "        )\n",
    "    )\n",
    "    model_cnn.add(MaxPooling1D(pool_size=2))\n",
    "    model_cnn.add(Flatten())\n",
    "    model_cnn.add(Dense(50, activation=\"relu\"))\n",
    "    model_cnn.add(Dense(1))\n",
    "\n",
    "    return model_cnn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "487ebeb6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def construct_lstm_model(files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    model_lstm = Sequential()\n",
    "    model_lstm.add(\n",
    "        LSTM(50, activation=\"relu\", input_shape=(X_train_series.shape[1], X_train_series.shape[2]))\n",
    "    )\n",
    "    model_lstm.add(Dense(1))\n",
    "\n",
    "    return model_lstm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9a23be52",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_valid_series.pkl\", order=ct.fs.Order.BEFORE, strategy=strategy\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "def construct_cnn_lstm_model(files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[1][1], \"rb\") as f:\n",
    "        X_valid_series = cloudpickle.load(f)\n",
    "\n",
    "    subsequences = 2\n",
    "    timesteps = X_train_series.shape[1] // subsequences\n",
    "    app_log.debug(\n",
    "        f\"Tensor shape info: {(X_train_series.shape[0], X_train_series.shape[1], timesteps)}\"\n",
    "    )\n",
    "    X_train_series_sub = X_train_series.reshape(\n",
    "        (X_train_series.shape[0], subsequences, timesteps, 1)\n",
    "    )\n",
    "    X_valid_series_sub = X_valid_series.reshape(\n",
    "        (X_valid_series.shape[0], subsequences, timesteps, 1)\n",
    "    )\n",
    "\n",
    "    model_cnn_lstm = Sequential()\n",
    "    model_cnn_lstm.add(\n",
    "        TimeDistributed(\n",
    "            Conv1D(filters=64, kernel_size=1, activation=\"relu\"),\n",
    "            input_shape=(None, X_train_series_sub.shape[2], X_train_series_sub.shape[3]),\n",
    "        )\n",
    "    )\n",
    "    model_cnn_lstm.add(TimeDistributed(MaxPooling1D(pool_size=2)))\n",
    "    model_cnn_lstm.add(TimeDistributed(Flatten()))\n",
    "    model_cnn_lstm.add(LSTM(50, activation=\"relu\"))\n",
    "    model_cnn_lstm.add(Dense(1))\n",
    "\n",
    "    return model_cnn_lstm\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a47ae10",
   "metadata": {},
   "source": [
    "### Train the Models\n",
    "\n",
    "Define tasks which train / fit the four models. We will offload these to a remote AWS Batch backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e50bf9bd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "batch_executor = ct.executor.AWSBatchExecutor(\n",
    "    credentials=\"/home/user/.aws/credentials\",\n",
    "    profile=\"default\",\n",
    "    batch_execution_role_name=\"dnn-tutorial-batch-task-execution-role\",\n",
    "    batch_job_definition_name=\"dnn-tutorial-batch-job-definition\",\n",
    "    batch_job_log_group_name=\"dnn-tutorial-batch-log-group\",\n",
    "    batch_job_role_name=\"dnn-tutorial-batch-job-role\",\n",
    "    batch_queue=\"dnn-tutorial-batch-queue\",\n",
    "    s3_bucket_name=\"dnn-tutorial-batch-job-resources\",\n",
    "    vcpu=2,\n",
    "    memory=4,\n",
    "    retry_attempts=2,\n",
    "    poll_freq=10,\n",
    "    time_limit=3000,\n",
    "    cache_dir=\"/tmp/covalent\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9c21d0c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "deps_pip = ct.DepsPip(packages=[\"tensorflow==2.9.1\", \"pandas==1.4.4\", \"cloudpickle==2.0.0\"])\n",
    "\n",
    "\n",
    "@ct.electron(\n",
    "    deps_pip=deps_pip,\n",
    "    executor=batch_executor,\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "def fit_mlp_model(model_mlp, learning_rate, epochs, files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[1][1], \"rb\") as f:\n",
    "        X_valid = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[2][1], \"rb\") as f:\n",
    "        Y_train = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[3][1], \"rb\") as f:\n",
    "        Y_valid = cloudpickle.load(f)\n",
    "\n",
    "    model_mlp.compile(loss=\"mse\", optimizer=optimizers.Adam(learning_rate))\n",
    "\n",
    "    return model_mlp.fit(\n",
    "        X_train.values, Y_train, validation_data=(X_valid.values, Y_valid), epochs=epochs\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9f70f793",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    deps_pip=deps_pip,\n",
    "    executor=batch_executor,\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_valid_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "def fit_cnn_model(model_cnn, learning_rate, epochs, files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[1][1], \"rb\") as f:\n",
    "        X_valid_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[2][1], \"rb\") as f:\n",
    "        Y_train = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[3][1], \"rb\") as f:\n",
    "        Y_valid = cloudpickle.load(f)\n",
    "\n",
    "    model_cnn.compile(loss=\"mse\", optimizer=optimizers.Adam(learning_rate))\n",
    "\n",
    "    return model_cnn.fit(\n",
    "        X_train_series, Y_train, validation_data=(X_valid_series, Y_valid), epochs=epochs\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8750a4e2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    deps_pip=deps_pip,\n",
    "    executor=batch_executor,\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_valid_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "def fit_lstm_model(model_lstm, learning_rate, epochs, files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[1][1], \"rb\") as f:\n",
    "        X_valid_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[2][1], \"rb\") as f:\n",
    "        Y_train = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[3][1], \"rb\") as f:\n",
    "        Y_valid = cloudpickle.load(f)\n",
    "\n",
    "    model_lstm.compile(loss=\"mse\", optimizer=optimizers.Adam(learning_rate))\n",
    "\n",
    "    return model_lstm.fit(\n",
    "        X_train_series, Y_train, validation_data=(X_valid_series, Y_valid), epochs=epochs\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "916e149c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.electron(\n",
    "    deps_pip=deps_pip,\n",
    "    executor=batch_executor,\n",
    "    files=[\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_train_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/X_valid_series.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_train.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "        ct.fs.TransferFromRemote(\n",
    "            f\"{data_bucket_uri}/Y_valid.pkl\", order=ct.fs.Order.AFTER, strategy=strategy\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "def fit_cnn_lstm_model(model_cnn_lstm, learning_rate, epochs, files=[]):\n",
    "    with open(files[0][1], \"rb\") as f:\n",
    "        X_train_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[1][1], \"rb\") as f:\n",
    "        X_valid_series = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[2][1], \"rb\") as f:\n",
    "        Y_train = cloudpickle.load(f)\n",
    "\n",
    "    with open(files[3][1], \"rb\") as f:\n",
    "        Y_valid = cloudpickle.load(f)\n",
    "\n",
    "    model_cnn_lstm.compile(loss=\"mse\", optimizer=optimizers.Adam(learning_rate))\n",
    "\n",
    "    subsequences = 2\n",
    "    timesteps = X_train_series.shape[1] // subsequences\n",
    "    X_train_series_sub = X_train_series.reshape(\n",
    "        (X_train_series.shape[0], subsequences, timesteps, 1)\n",
    "    )\n",
    "    X_valid_series_sub = X_valid_series.reshape(\n",
    "        (X_valid_series.shape[0], subsequences, timesteps, 1)\n",
    "    )\n",
    "\n",
    "    return model_cnn_lstm.fit(\n",
    "        X_train_series_sub, Y_train, validation_data=(X_valid_series_sub, Y_valid), epochs=epochs\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5507c61",
   "metadata": {},
   "source": [
    "### Visualize the results\n",
    "\n",
    "We then plot our loss fuctions with respect to epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fe4aca3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ct.electron\n",
    "def visualize_loss_functions(mlp_history, cnn_history, lstm_history, cnn_lstm_history):\n",
    "    # One panel per model on a 2x2 grid that shares the epoch axis\n",
    "    fig, axes = plt.subplots(2, 2, sharex=True, sharey=False, figsize=(22, 12))\n",
    "    histories = {\n",
    "        \"MLP\": mlp_history,\n",
    "        \"CNN\": cnn_history,\n",
    "        \"LSTM\": lstm_history,\n",
    "        \"CNN-LSTM\": cnn_lstm_history,\n",
    "    }\n",
    "\n",
    "    # axes.flat walks the grid row by row: MLP, CNN, LSTM, CNN-LSTM\n",
    "    for ax, (title, history) in zip(axes.flat, histories.items()):\n",
    "        ax.plot(history.history[\"loss\"], label=\"Train loss\")\n",
    "        ax.plot(history.history[\"val_loss\"], label=\"Validation loss\")\n",
    "        ax.legend(loc=\"best\")\n",
    "        ax.set_title(title)\n",
    "        ax.set_xlabel(\"Epochs\")\n",
    "        ax.set_ylabel(\"MSE\")\n",
    "\n",
    "    # Save the figure on the machine where this electron runs\n",
    "    plt.savefig(\"/tmp/comparison.png\")\n"
   ]
  },
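  {
   "cell_type": "markdown",
   "id": "0426b8c8",
   "metadata": {},
   "source": [
    "### Compare Models\n",
    "We are now ready to construct a workflow that compares the four DNN models using a variety of cloud-based compute resources. Several electrons exchange data through files rather than return values, so the workflow uses `ct.wait` to state explicitly which tasks must finish before a dependent task starts."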
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4e8c5f39",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@ct.lattice\n",
    "def compare_models(window, lag_size, learning_rate, epochs):\n",
    "    # Fetch and validate source data\n",
    "    train, test = fetch_and_validate_data()\n",
    "\n",
    "    # Explore the data\n",
    "    daily_sales_plt, sales_by_item_plt = explore_data(train)\n",
    "\n",
    "    # Subsample and transform the data followed by splitting the dataset\n",
    "    transform_data_call = transform_data(window, lag_size)\n",
    "    split_data_call = split_data(lag_size)\n",
    "\n",
    "    ct.wait(child=split_data_call, parents=[transform_data_call])\n",
    "\n",
    "    # Construct the models\n",
    "    model_mlp = construct_mlp_model()\n",
    "    ct.wait(child=model_mlp, parents=[split_data_call])\n",
    "\n",
    "    model_cnn = construct_cnn_model()\n",
    "    ct.wait(child=model_cnn, parents=[split_data_call])\n",
    "\n",
    "    model_lstm = construct_lstm_model()\n",
    "    ct.wait(child=model_lstm, parents=[split_data_call])\n",
    "\n",
    "    model_cnn_lstm = construct_cnn_lstm_model()\n",
    "    ct.wait(child=model_cnn_lstm, parents=[split_data_call])\n",
    "\n",
    "    # Train the models\n",
    "    mlp_history = fit_mlp_model(model_mlp=model_mlp, learning_rate=learning_rate, epochs=epochs)\n",
    "    ct.wait(child=mlp_history, parents=[model_mlp])\n",
    "\n",
    "    cnn_history = fit_cnn_model(model_cnn=model_cnn, learning_rate=learning_rate, epochs=epochs)\n",
    "    ct.wait(child=cnn_history, parents=[model_cnn])\n",
    "\n",
    "    lstm_history = fit_lstm_model(\n",
    "        model_lstm=model_lstm, learning_rate=learning_rate, epochs=epochs\n",
    "    )\n",
    "    ct.wait(child=lstm_history, parents=[model_lstm])\n",
    "\n",
    "    cnn_lstm_history = fit_cnn_lstm_model(\n",
    "        model_cnn_lstm=model_cnn_lstm, learning_rate=learning_rate, epochs=epochs\n",
    "    )\n",
    "    ct.wait(child=cnn_lstm_history, parents=[model_cnn_lstm])\n",
    "\n",
    "    # Generate loss function plot\n",
    "    visualize_loss_functions(mlp_history, cnn_history, lstm_history, cnn_lstm_history)\n",
    "\n",
    "    return mlp_history, cnn_history, lstm_history, cnn_lstm_history\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad27f33",
   "metadata": {},
   "source": [
    "### Run the workflow!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "819ce9f1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1eeb141a-518c-4f9b-8aaf-10d46beca7de\n"
     ]
    }
   ],
   "source": [
    "dispatch_id = ct.dispatch(compare_models)(window=29, lag_size=90, learning_rate=0.0003, epochs=40)\n",
    "\n",
    "print(dispatch_id)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f855e406",
   "metadata": {},
   "source": [
    "### Last Steps\n",
    "\n",
    "You are now free to kill the Python kernel, take a walk, and come back when your workflow is done running. Want reminders? Ask about e-mail or Slack notifications upon completion/failure...\n",
    "\n",
    "**But first, let's take a look at the UI in the browser...**\n",
    "\n",
    "<img src=\"./images/workflow.png\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7254c186",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "res = ct.get_result(dispatch_id, wait=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "90d901f4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Workflow status: COMPLETED \n",
      "Computation time: 62.3 minutes\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"Workflow status: {res.status} \\nComputation time: {((res.end_time - res.start_time).total_seconds() / 60):.1f} minutes\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ea6e059",
   "metadata": {},
   "source": [
    "### Plots\n",
    "Finally, let's look at the training and validation loss curves recorded for each model during training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a812d858",
   "metadata": {},
   "source": [
    "<img src=\"./images/comparison.png\">"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "65f23ff11413a1b24e6045f226bbb649c5d31fd62a4c12b34b399c99ac705181"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
